/**
 * 2017年5月19日
 */
package cn.edu.bjtu.workbench.datasource.dsiter;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.stream.Collectors;

import org.datavec.api.writable.Writable;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

import cn.edu.bjtu.workbench.datasource.fileiter.TransformedVecWithIdLineVectorRecordReader;

/**
 * {@link DataSetIterator} built from the records of a
 * {@link TransformedVecWithIdLineVectorRecordReader}.
 * <p>
 * The constructor drains the reader eagerly, converts every record into a
 * (features, one-hot label) pair and hands the resulting list to an
 * {@link INDArrayDataSetIterator}. Every iterator method of this class simply
 * delegates to that wrapped instance.
 *
 * @author Alex
 */
public class TransformedVecDataSetIterator implements DataSetIterator {
	/** Number of classes, i.e. the length of every one-hot label vector. */
	private static final int CLASS_NUM = 15;
	private static final long serialVersionUID = -4574233893416571326L;
	/** In-memory iterator that all {@link DataSetIterator} calls are delegated to. */
	private final INDArrayDataSetIterator iter;

	/**
	 * Reads all remaining records from {@code rr} and prepares them for batched iteration.
	 * <p>
	 * Each record must carry the class label as element 0 and the feature vector as
	 * element 1, the latter rendered as whitespace-separated decimal numbers. Labels are
	 * 1-based: label {@code k} sets position {@code k - 1} of the one-hot vector.
	 * Records whose first feature value is {@code 0.0} are skipped.
	 *
	 * @param rr        record reader to drain; it is consumed completely
	 * @param batchSize batch size passed on to the wrapped {@link INDArrayDataSetIterator}
	 * @throws IllegalArgumentException if a kept record has a label outside {@code 1..CLASS_NUM}
	 * @throws NumberFormatException    if a feature token is not a parsable double
	 */
	public TransformedVecDataSetIterator(TransformedVecWithIdLineVectorRecordReader rr, int batchSize) {
		List<Pair<INDArray, INDArray>> examples = new ArrayList<>(10000);
		while (rr.hasNext()) {
			// The reader guarantees: element 0 is the class label, element 1 the features.
			List<Writable> record = rr.next();
			double label = record.get(0).toDouble();
			INDArray features = Nd4j.create(parseFeatures(record.get(1).toString()));
			// The input contains anomalous rows (reported as all zeros) that must be excluded;
			// only the first value is inspected. This filter can go once the upstream
			// language filtering is in place. Checked before the label is processed so that
			// discarded rows cost no extra allocation and cannot fail label validation.
			if (features.getDouble(0) == 0.0) {
				continue;
			}
			INDArray labels = Nd4j.create(oneHotLabel(label));
			examples.add(new Pair<INDArray, INDArray>(features, labels));
		}
		iter = new INDArrayDataSetIterator(examples, batchSize);
	}

	/**
	 * Parses a whitespace-separated list of numbers into a primitive array.
	 * Leading/trailing whitespace and repeated separators are tolerated.
	 */
	private static double[] parseFeatures(String text) {
		return Arrays.stream(text.trim().split("\\s+")).mapToDouble(Double::parseDouble).toArray();
	}

	/**
	 * Builds the one-hot vector of length {@link #CLASS_NUM} for a 1-based class label.
	 * The label is truncated to an int before use.
	 */
	private static double[] oneHotLabel(double label) {
		int index = ((int) label) - 1;
		if (index < 0 || index >= CLASS_NUM) {
			throw new IllegalArgumentException(
					"Class label " + label + " is outside the supported range 1.." + CLASS_NUM);
		}
		double[] oneHot = new double[CLASS_NUM];
		oneHot[index] = 1;
		return oneHot;
	}

	/**
	 * Scratch entry point that prints the row count, column count, size of dimension 0
	 * and first element of a two-element vector created with {@link Nd4j#create(double[])}.
	 *
	 * @param args ignored
	 */
	public static void main(String[] args) {
		INDArray vec = Nd4j.create(new double[]{3.0, 4.0});
		System.out.println(vec.rows());
		System.out.println(vec.columns());
		System.out.println(vec.size(0));
		System.out.println(vec.getDouble(0, 0));
	}

	/**
	 * Returns the hash code of the wrapped iterator, consistent with {@link #equals(Object)}.
	 *
	 * @return hash code of the delegate
	 */
	@Override
	public int hashCode() {
		return iter.hashCode();
	}

	/**
	 * Two instances are equal when their wrapped iterators are equal.
	 * <p>
	 * The comparison is made between the delegates of both wrappers rather than between
	 * the delegate and {@code obj} itself, so that {@code x.equals(x)} holds.
	 *
	 * @param obj object to compare with
	 * @return {@code true} if {@code obj} is a {@code TransformedVecDataSetIterator}
	 *         wrapping an equal iterator
	 */
	@Override
	public boolean equals(Object obj) {
		if (this == obj) {
			return true;
		}
		if (!(obj instanceof TransformedVecDataSetIterator)) {
			return false;
		}
		return iter.equals(((TransformedVecDataSetIterator) obj).iter);
	}

	/**
	 * Returns the string form of the wrapped iterator.
	 *
	 * @return {@code toString()} of the delegate
	 */
	@Override
	public String toString() {
		return iter.toString();
	}

	/**
	 * Delegates to {@link INDArrayDataSetIterator#next(int)}.
	 *
	 * @param num value forwarded unchanged to the delegate
	 * @return the data set returned by the delegate
	 */
	public DataSet next(int num) {
		return iter.next(num);
	}

	/**
	 * Delegates to {@link INDArrayDataSetIterator#totalExamples()}.
	 *
	 * @return value reported by the delegate
	 */
	public int totalExamples() {
		return iter.totalExamples();
	}

	/**
	 * Delegates to {@link INDArrayDataSetIterator#inputColumns()}.
	 *
	 * @return value reported by the delegate
	 */
	public int inputColumns() {
		return iter.inputColumns();
	}

	/**
	 * Delegates to {@link INDArrayDataSetIterator#totalOutcomes()}.
	 *
	 * @return value reported by the delegate
	 */
	public int totalOutcomes() {
		return iter.totalOutcomes();
	}

	/**
	 * Delegates to {@link INDArrayDataSetIterator#resetSupported()}.
	 *
	 * @return value reported by the delegate
	 */
	public boolean resetSupported() {
		return iter.resetSupported();
	}

	/**
	 * Delegates to {@link INDArrayDataSetIterator#asyncSupported()}.
	 *
	 * @return value reported by the delegate
	 */
	public boolean asyncSupported() {
		return iter.asyncSupported();
	}

	/**
	 * Delegates to {@link INDArrayDataSetIterator#reset()}.
	 */
	public void reset() {
		iter.reset();
	}

	/**
	 * Delegates to {@link INDArrayDataSetIterator#batch()}.
	 *
	 * @return value reported by the delegate
	 */
	public int batch() {
		return iter.batch();
	}

	/**
	 * Delegates to {@link INDArrayDataSetIterator#cursor()}.
	 *
	 * @return value reported by the delegate
	 */
	public int cursor() {
		return iter.cursor();
	}

	/**
	 * Delegates to {@link INDArrayDataSetIterator#numExamples()}.
	 *
	 * @return value reported by the delegate
	 */
	public int numExamples() {
		return iter.numExamples();
	}

	/**
	 * Installs a pre-processor on the wrapped iterator.
	 *
	 * @param preProcessor pre-processor forwarded to the delegate
	 */
	public void setPreProcessor(DataSetPreProcessor preProcessor) {
		iter.setPreProcessor(preProcessor);
	}

	/**
	 * Delegates to {@link INDArrayDataSetIterator#getPreProcessor()}.
	 *
	 * @return pre-processor reported by the delegate
	 */
	public DataSetPreProcessor getPreProcessor() {
		return iter.getPreProcessor();
	}

	/**
	 * Delegates to {@link INDArrayDataSetIterator#getLabels()}.
	 *
	 * @return label names reported by the delegate
	 */
	public List<String> getLabels() {
		return iter.getLabels();
	}

	/**
	 * Delegates to {@link INDArrayDataSetIterator#hasNext()}.
	 *
	 * @return value reported by the delegate
	 */
	public boolean hasNext() {
		return iter.hasNext();
	}

	/**
	 * Delegates to {@link INDArrayDataSetIterator#next()}.
	 *
	 * @return the data set returned by the delegate
	 * @throws NoSuchElementException propagated from the delegate
	 */
	public DataSet next() throws NoSuchElementException {
		return iter.next();
	}

	/**
	 * Delegates to {@link INDArrayDataSetIterator#remove()}.
	 */
	public void remove() {
		iter.remove();
	}

}
